package com.xwder.example.filter;

import cn.hutool.core.util.StrUtil;
import com.xwder.example.advice.BizException;
import com.xwder.example.common.result.ResultCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Request wrapper that rejects parameters, headers and JSON bodies containing
 * XSS- or SQL-injection-like content.
 *
 * <p>For requests whose content type contains {@code application/json} the body is read
 * once in the constructor, validated, and cached so that {@link #getInputStream()} and
 * {@link #getReader()} can be called repeatedly. Parameters and headers are validated
 * lazily, each time they are read through this wrapper. A rejected value results in a
 * {@link BizException}; nothing is ever rewritten or escaped.
 *
 * @author xwder
 * @date 2021/3/31 10:15
 **/
public class XsslHttpServletRequestWrapper extends HttpServletRequestWrapper {
    private static final Logger logger = LoggerFactory.getLogger(XsslHttpServletRequestWrapper.class);

    /** The wrapped request; the reason of the last rejection is stored on it as an attribute. */
    HttpServletRequest xssRequest = null;

    /** Name of the request attribute that receives the label of the rule that rejected a value. */
    private static final String XSS_ERROR_CODE = "XSS_ERROR_CODE";

    /** Flag combination shared by most detection patterns. */
    private static final int MULTI_LINE_FLAGS = Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL;

    /**
     * Lower-case keywords and URL-encoded fragments that must not appear in a value.
     */
    private static final String[] UNSAFE_TOKENS = {
            "<script",
            "</script",
            "<iframe",
            "</iframe",
            "<frame",
            "</frame",
            "set-cookie",
            "%3cscript",
            "%3c/script",
            "%3ciframe",
            "%3c/iframe",
            "%3cframe",
            "%3c/frame",
            "src=\"javascript:",
            "<body", "</body",
            "%3cbody",
            "%3c/body",
            "%3c", "%3e", "%3c/", "/%3e"};

    /**
     * Detection rules, compiled once. When several rules match the same value, the rule
     * that appears LAST in this array determines the reported label.
     */
    private static final XssRule[] XSS_RULES = {
            // anything between script tags
            new XssRule("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE, "<script>(...)</script>"),
            // src='...'
            new XssRule("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", MULTI_LINE_FLAGS, "src=(...)"),
            // src="..."
            new XssRule("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", MULTI_LINE_FLAGS, "src=(...)"),
            // lonesome </script> tag
            new XssRule("</script>", Pattern.CASE_INSENSITIVE, "</script>"),
            // lonesome <script ...> tag
            new XssRule("<script(.*?)>", MULTI_LINE_FLAGS, "<script(...)>"),
            // eval(...)
            new XssRule("eval\\((.*?)\\)", MULTI_LINE_FLAGS, "eval(...)"),
            // expression(...) -- the previous regex contained an invisible soft hyphen
            // (U+00AD) inside the word and therefore could never match real input
            new XssRule("expression\\((.*?)\\)", MULTI_LINE_FLAGS, "expression(...)"),
            // javascript:
            new XssRule("javascript:", Pattern.CASE_INSENSITIVE, "javascript:"),
            // alert(...)
            new XssRule("alert\\((.*?)\\)", MULTI_LINE_FLAGS, "alert(...)"),
            // onload ... =
            new XssRule("onload(.*?)=", MULTI_LINE_FLAGS, "onload(...)"),
            // vbscript:
            new XssRule("vbscript[\r\n| | ]*:[\r\n| | ]*", Pattern.CASE_INSENSITIVE, "vbscript...:..."),
            // any "...script:" scheme
            new XssRule("script:", Pattern.CASE_INSENSITIVE, "javascript:"),
            // a tag that carries an event handler, src, script or a call expression
            new XssRule("(<[^<>]*?(onload|onchange|onclick|onerror|src|prompt|alert|script|\\([\\s\\S]*\\))[^<>]*?>)",
                    MULTI_LINE_FLAGS, "< ()>:")
    };

    /** Matches {@code "operator":"..."} JSON members, which are excluded from keyword checks. */
    private static final Pattern OPERATOR_FIELD_PATTERN = Pattern.compile("\"operator\":\"[^\",]*\"");

    /**
     * SQL keywords surrounded by non-word characters. {@code count\(\*\)} is escaped so that
     * it matches the literal text {@code count(*)}; the previous {@code count(\*)} was a
     * capture group matching {@code count*}.
     */
    private static final Pattern SQL_KEYWORD_PATTERN = Pattern.compile(
            "\\W+(exec|execute|insert|select|delete|update|create|drop|chr|mid|master|truncate|char|declare|sitename|net user|xp_cmdshell|like'|table|from|grant|use|group_concat|column_name|count\\(\\*\\)|count\\(|information_schema.columns|table_schema|union|where|order by)\\W+",
            Pattern.CASE_INSENSITIVE);

    /** Characters / sequences that are rejected wherever they occur: {@code $}, {@code --}, {@code '}. */
    private static final Pattern SQL_META_CHAR_PATTERN = Pattern.compile("(\\$|--|\')");

    /**
     * Cached request body (UTF-8 encoded); {@code null} unless the request is a JSON request.
     */
    private final byte[] body;
    private final boolean isJsonPost;

    /**
     * Wraps {@code request}. For JSON requests the body is read, validated and cached.
     *
     * @param request the request to wrap
     * @throws BizException     if the JSON body contains rejected content
     * @throws RuntimeException if the body cannot be read
     */
    public XsslHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        xssRequest = request;

        byte[] cachedBody = null;
        boolean json = false;
        if (isJsonContentType(request.getContentType())) {
            String bodyStr = getBodyString(request);
            replaceXSS(bodyStr);
            // The body was decoded as UTF-8, so it is stored as UTF-8 as well; a null body
            // is cached as empty instead of failing with a NullPointerException.
            cachedBody = (bodyStr == null ? "" : bodyStr).getBytes(StandardCharsets.UTF_8);
            json = true;
        }
        this.body = cachedBody;
        this.isJsonPost = json;
    }

    /**
     * Tells whether a content type denotes JSON. Media types are case-insensitive, so the
     * comparison is done on a locale-independent lower-case copy.
     */
    private static boolean isJsonContentType(String contentType) {
        return contentType != null && contentType.toLowerCase(Locale.ROOT).contains("application/json");
    }

    /**
     * Reads the whole body of {@code request} as a UTF-8 string.
     *
     * <p>This consumes and closes the request's input stream. Line terminators are not
     * part of the returned string.
     *
     * @param request request whose body is read
     * @return the body text
     * @throws RuntimeException if reading fails
     */
    public String getBodyString(final ServletRequest request) {
        try {
            return inputStream2String(request.getInputStream());
        } catch (IOException e) {
            logger.error("", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the cached JSON body.
     *
     * @return the body text, or {@code null} when this is not a JSON request
     */
    public String getBodyString() {
        if (!isJsonPost) {
            return null;
        }
        return new String(body, StandardCharsets.UTF_8);
    }

    /**
     * Reads {@code inputStream} to the end as UTF-8 text and closes it.
     *
     * <p>The text is read line by line and the lines are concatenated without their
     * terminators, so CR/LF characters of the input are dropped.
     *
     * @param inputStream stream to drain
     * @return the concatenated lines
     * @throws RuntimeException if reading fails
     */
    private String inputStream2String(InputStream inputStream) {
        StringBuilder text = new StringBuilder();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
            String line;
            while ((line = reader.readLine()) != null) {
                text.append(line);
            }
        } catch (IOException e) {
            logger.error("", e);
            throw new RuntimeException(e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e) {
                    // Best effort: a failure to close must not mask the result.
                    logger.error("", e);
                }
            }
        }
        return text.toString();
    }

    /**
     * Returns a reader over the cached body for JSON requests, otherwise the wrapped
     * request's reader. The cached body is UTF-8 encoded, so it is decoded as UTF-8
     * rather than with the platform default charset.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (!isJsonPost) {
            return super.getReader();
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Returns a fresh stream over the cached body for JSON requests, otherwise the
     * wrapped request's stream.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (!isJsonPost) {
            return super.getInputStream();
        }
        return new CachedBodyInputStream(body);
    }

    /**
     * @return the request that was passed to the constructor
     */
    public HttpServletRequest getXssRequest() {
        return xssRequest;
    }

    /**
     * Validates the parameter name and its value before returning the value.
     *
     * @throws BizException if the name or the value is rejected
     */
    @Override
    public String getParameter(String name) {
        return replaceXSS(super.getParameter(replaceXSS(name)));
    }

    /**
     * Validates the parameter name and every value before returning the values.
     *
     * @throws BizException if the name or any value is rejected
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(replaceXSS(name));
        if (values != null) {
            for (String candidate : values) {
                replaceXSS(candidate);
            }
        }
        return values;
    }

    /**
     * Validates the header name and its value before returning the value.
     *
     * @throws BizException if the name or the value is rejected
     */
    @Override
    public String getHeader(String name) {
        return replaceXSS(super.getHeader(replaceXSS(name)));
    }

    /**
     * Validates {@code value}. Despite its name this method never rewrites the value: it
     * either returns it unchanged or throws.
     *
     * <p>On rejection the label of the offending rule is stored in the
     * {@code XSS_ERROR_CODE} request attribute before the exception is thrown. The
     * decision is made for the given value only; an attribute left over from an earlier
     * rejection on the same request does not cause clean values to be rejected.
     *
     * @param value text to validate, may be {@code null}
     * @return {@code value}, unchanged
     * @throws BizException if the value matches any rule
     */
    public String replaceXSS(String value) {
        if (value == null) {
            return null;
        }
        String errorInfo = findViolation(value);
        if (errorInfo != null) {
            xssRequest.setAttribute(XSS_ERROR_CODE, errorInfo);
            logger.error("该内容包含非法字符：{}", value);
            String exceptionMsg = "请求操作内容中含有非法字符：" + errorInfo;
            throw new BizException(ResultCode.PARAM_VALIDATE_FAILD.getCode(), exceptionMsg);
        }
        return value;
    }

    /**
     * Returns the label of the highest-priority rule that {@code value} violates, or
     * {@code null} when the value is clean. Priority, highest first: unsafe-token list,
     * SQL check, then {@link #XSS_RULES} from last to first.
     */
    private String findViolation(String value) {
        if (!isSafe(value)) {
            return "XSS...:...";
        }
        if (judgeSQLInject(value)) {
            return "SQL inject...:...";
        }
        for (int i = XSS_RULES.length - 1; i >= 0; i--) {
            if (XSS_RULES[i].matches(value)) {
                return XSS_RULES[i].label;
            }
        }
        return null;
    }

    /**
     * Removes every {@code "operator":"..."} member from {@code text}.
     */
    private static String stripOperatorFields(String text) {
        return OPERATOR_FIELD_PATTERN.matcher(text).replaceAll("");
    }

    /**
     * Tells whether {@code value} looks like an SQL-injection attempt.
     *
     * <p>{@code "operator":"..."} members are ignored. The remainder is rejected when it
     * contains one of the SQL keywords surrounded by non-word characters, or any of
     * {@code $}, {@code --} or {@code '}.
     *
     * @param value text to inspect, may be {@code null} or empty
     * @return {@code true} if the value is considered an attack string
     */
    public boolean judgeSQLInject(String value) {
        if (StrUtil.isEmpty(value)) {
            return false;
        }
        String candidate = stripOperatorFields(value);
        return SQL_KEYWORD_PATTERN.matcher(candidate).find()
                || SQL_META_CHAR_PATTERN.matcher(candidate).find();
    }

    /**
     * Tells whether {@code str} is free of every entry of {@link #UNSAFE_TOKENS}.
     *
     * <p>{@code "operator":"..."} members are ignored. The comparison is done on a
     * {@link Locale#ROOT} lower-case copy so that the result does not depend on the
     * platform locale (e.g. Turkish dotless i).
     */
    private static boolean isSafe(String str) {
        if (str == null || str.isEmpty()) {
            return true;
        }
        String normalized = stripOperatorFields(str).toLowerCase(Locale.ROOT);
        for (String token : UNSAFE_TOKENS) {
            if (normalized.contains(token)) {
                return false;
            }
        }
        return true;
    }

    /** A compiled detection pattern together with the label reported when it matches. */
    private static final class XssRule {
        private final Pattern pattern;
        private final String label;

        XssRule(String regex, int flags, String label) {
            this.pattern = Pattern.compile(regex, flags);
            this.label = label;
        }

        boolean matches(String value) {
            return pattern.matcher(value).find();
        }
    }

    /** {@link ServletInputStream} view over an in-memory copy of the request body. */
    private static final class CachedBodyInputStream extends ServletInputStream {
        private final ByteArrayInputStream delegate;

        CachedBodyInputStream(byte[] content) {
            this.delegate = new ByteArrayInputStream(content);
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] buffer, int offset, int length) {
            return delegate.read(buffer, offset, length);
        }

        @Override
        public int available() {
            return delegate.available();
        }

        @Override
        public boolean isFinished() {
            // Finished once every cached byte has been handed out.
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // In-memory data can always be read without blocking.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Non-blocking reads are not supported; the listener is ignored.
        }
    }
}